{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# V-JEPA 2 Demo Notebook\n",
        "\n",
        "This tutorial provides an example of how to load the V-JEPA 2 model in vanilla PyTorch and HuggingFace, extract a video embedding, and then predict an action class. For more details about the paper and model weights, please see https://github.com/facebookresearch/vjepa2."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's import the necessary libraries and define the helper functions used in this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import subprocess\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from decord import VideoReader\n",
        "from transformers import AutoVideoProcessor, AutoModel\n",
        "\n",
        "import src.datasets.utils.video.transforms as video_transforms\n",
        "import src.datasets.utils.video.volume_transforms as volume_transforms\n",
        "from src.models.attentive_pooler import AttentiveClassifier\n",
        "from src.models.vision_transformer import vit_giant_xformers_rope\n",
        "\n",
        "IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\n",
        "IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\n",
        "\n",
        "def load_pretrained_vjepa_pt_weights(model, pretrained_weights):\n",
        "    # Load weights of the VJEPA2 encoder\n",
        "    # The PyTorch state_dict is already preprocessed to have the right key names\n",
        "    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location=\"cpu\")[\"encoder\"]\n",
        "    pretrained_dict = {k.replace(\"module.\", \"\"): v for k, v in pretrained_dict.items()}\n",
        "    pretrained_dict = {k.replace(\"backbone.\", \"\"): v for k, v in pretrained_dict.items()}\n",
        "    msg = model.load_state_dict(pretrained_dict, strict=False)\n",
        "    print(\"Pretrained weights found at {} and loaded with msg: {}\".format(pretrained_weights, msg))\n",
        "\n",
        "\n",
        "def load_pretrained_vjepa_classifier_weights(model, pretrained_weights):\n",
        "    # Load weights of the VJEPA2 classifier\n",
        "    # The PyTorch state_dict is already preprocessed to have the right key names\n",
        "    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location=\"cpu\")[\"classifiers\"][0]\n",
        "    pretrained_dict = {k.replace(\"module.\", \"\"): v for k, v in pretrained_dict.items()}\n",
        "    msg = model.load_state_dict(pretrained_dict, strict=False)\n",
        "    print(\"Pretrained weights found at {} and loaded with msg: {}\".format(pretrained_weights, msg))\n",
        "\n",
        "\n",
        "def build_pt_video_transform(img_size):\n",
        "    \"\"\"Build the evaluation preprocessing pipeline for the PyTorch model.\n",
        "\n",
        "    The pipeline applies, in order: `Resize` to `int(256.0 / 224 * img_size)` with bilinear\n",
        "    interpolation, `CenterCrop` to `img_size` x `img_size`, `ClipToTensor`, and `Normalize`\n",
        "    with `IMAGENET_DEFAULT_MEAN` / `IMAGENET_DEFAULT_STD`. No random crop or flip is used.\n",
        "    \"\"\"\n",
        "    resize_to = int(256.0 / 224 * img_size)\n",
        "    eval_steps = [\n",
        "        video_transforms.Resize(resize_to, interpolation=\"bilinear\"),\n",
        "        video_transforms.CenterCrop(size=(img_size, img_size)),\n",
        "        volume_transforms.ClipToTensor(),\n",
        "        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n",
        "    ]\n",
        "    return video_transforms.Compose(eval_steps)\n",
        "\n",
        "\n",
        "def get_video(video_path=\"sample_video.mp4\", num_frames=64, frame_step=2):\n",
        "    \"\"\"Decode a fixed-stride selection of frames from a video file.\n",
        "\n",
        "    Args:\n",
        "        video_path: Path of the video to read. Defaults to the sample clip used in this tutorial.\n",
        "        num_frames: Number of frames to sample.\n",
        "        frame_step: Distance between two consecutive sampled frame indices.\n",
        "\n",
        "    Returns:\n",
        "        The result of `vr.get_batch(frame_idx).asnumpy()` for the frame indices\n",
        "        `0, frame_step, ..., (num_frames - 1) * frame_step`.\n",
        "    \"\"\"\n",
        "    vr = VideoReader(video_path)\n",
        "    # Simple fixed-stride sampling from the start of the clip; with the defaults this is\n",
        "    # `np.arange(0, 128, 2)`. You can define a more complex sampling strategy here.\n",
        "    frame_idx = np.arange(0, num_frames * frame_step, frame_step)\n",
        "    return vr.get_batch(frame_idx).asnumpy()\n",
        "\n",
        "\n",
        "def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):\n",
        "    \"\"\"Encode the sample video with both the HuggingFace and the PyTorch model.\n",
        "\n",
        "    Returns:\n",
        "        `(out_patch_features_hf, out_patch_features_pt)`: the output of\n",
        "        `model_hf.get_vision_features` and the output of `model_pt`, in that order.\n",
        "    \"\"\"\n",
        "    with torch.inference_mode():\n",
        "        # Read the sampled video frames and move the channel axis ahead of the spatial axes\n",
        "        frames = torch.from_numpy(get_video())  # T x H x W x C\n",
        "        frames = frames.permute(0, 3, 1, 2)  # T x C x H x W\n",
        "\n",
        "        # PyTorch input: transform, move to the GPU, then add a leading axis\n",
        "        pt_clip = pt_transform(frames)\n",
        "        pt_batch = pt_clip.cuda().unsqueeze(0)\n",
        "\n",
        "        # HuggingFace input: take the `pixel_values_videos` entry of the processor output\n",
        "        hf_inputs = hf_transform(frames, return_tensors=\"pt\")\n",
        "        hf_batch = hf_inputs[\"pixel_values_videos\"].to(\"cuda\")\n",
        "\n",
        "        # Extract the patch-wise features from the last layer\n",
        "        pt_features = model_pt(pt_batch)\n",
        "        hf_features = model_hf.get_vision_features(hf_batch)\n",
        "\n",
        "    return hf_features, pt_features\n",
        "\n",
        "\n",
        "def get_vjepa_video_classification_results(classifier, out_patch_features_pt):\n",
        "    SOMETHING_SOMETHING_V2_CLASSES = json.load(open(\"ssv2_classes.json\", \"r\"))\n",
        "\n",
        "    with torch.inference_mode():\n",
        "        out_classifier = classifier(out_patch_features_pt)\n",
        "\n",
        "    print(f\"Classifier output shape: {out_classifier.shape}\")\n",
        "\n",
        "    print(\"Top 5 predicted class names:\")\n",
        "    top5_indices = out_classifier.topk(5).indices[0]\n",
        "    top5_probs = F.softmax(out_classifier.topk(5).values[0]) * 100.0  # convert to percentage\n",
        "    for idx, prob in zip(top5_indices, top5_probs):\n",
        "        str_idx = str(idx.item())\n",
        "        print(f\"{SOMETHING_SOMETHING_V2_CLASSES[str_idx]} ({prob}%)\")\n",
        "\n",
        "    return"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, let's download a sample video to the local repository. If the video is already downloaded, the code will skip this step. Likewise, let's download a mapping for the action recognition classes used in Something-Something V2, so we can interpret the predicted action class from our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "sample_video_path = \"sample_video.mp4\"\n",
        "ssv2_classes_path = \"ssv2_classes.json\"\n",
        "\n",
        "video_url = \"https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4\"\n",
        "ssv2_classes_url = (\n",
        "    \"https://huggingface.co/datasets/huggingface/label-files/resolve/d79675f2d50a7b1ecf98923d42c30526a51818e2/\"\n",
        "    \"something-something-v2-id2label.json\"\n",
        ")\n",
        "\n",
        "# Fetch each file only if it is not already present at its local path\n",
        "for description, url, local_path in (\n",
        "    (\"video\", video_url, sample_video_path),\n",
        "    (\"SSV2 classes\", ssv2_classes_url, ssv2_classes_path),\n",
        "):\n",
        "    if os.path.exists(local_path):\n",
        "        continue\n",
        "    # Announce the download before it starts\n",
        "    print(f\"Downloading {description}\")\n",
        "    try:\n",
        "        # check=True makes a failed download raise here instead of surfacing later as a\n",
        "        # confusing decode/parse error\n",
        "        subprocess.run([\"wget\", url, \"-O\", local_path], check=True)\n",
        "    except subprocess.CalledProcessError:\n",
        "        # `wget -O` creates the output file even when the download fails; remove it so that\n",
        "        # re-running this cell retries instead of skipping the download\n",
        "        if os.path.exists(local_path):\n",
        "            os.remove(local_path)\n",
        "        raise"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, let's load the models both in vanilla PyTorch and through the HuggingFace API. Note that the HuggingFace API will automatically load the weights through `from_pretrained()`, so there is no additional download required for HuggingFace.\n",
        "\n",
        "To download the PyTorch model weights, use wget and specify your preferred target path. See the README for the model weight URLs.\n",
        "E.g. \n",
        "```\n",
        "wget https://dl.fbaipublicfiles.com/vjepa2/vitg-384.pt -P YOUR_DIR\n",
        "```\n",
        "Then update `pt_model_path` with `YOUR_DIR/vitg-384.pt`. Also note that you have the option to use `torch.hub.load`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# HuggingFace model repo name; replace with your favored model\n",
        "hf_model_name = \"facebook/vjepa2-vitg-fpc64-384\"\n",
        "# Path to local PyTorch weights\n",
        "pt_model_path = \"YOUR_MODEL_PATH\"\n",
        "\n",
        "# HuggingFace model: `from_pretrained` loads the weights, then move to GPU in eval mode\n",
        "model_hf = AutoModel.from_pretrained(hf_model_name)\n",
        "model_hf = model_hf.cuda().eval()\n",
        "\n",
        "# HuggingFace preprocessing; its crop size is reused as the PyTorch model's input size\n",
        "hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)\n",
        "img_size = hf_transform.crop_size[\"height\"]  # E.g. 384, 256, etc.\n",
        "\n",
        "# PyTorch model: build it on GPU in eval mode, then load the local pretrained weights\n",
        "model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64).cuda().eval()\n",
        "load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)\n",
        "\n",
        "### Alternative: load the PyTorch model through torch.hub\n",
        "# model_pt, _ = torch.hub.load('facebookresearch/vjepa2', 'vjepa2_vit_giant_384')\n",
        "# model_pt.cuda().eval()\n",
        "\n",
        "# PyTorch preprocessing transform at the same input size\n",
        "pt_video_transform = build_pt_video_transform(img_size=img_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can run the encoder on the video to get the patch-wise features from the last layer of the encoder. To verify that the HuggingFace and PyTorch models are equivalent, we will compare the values of the features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Encode the sample video with both models to get the patch-wise features\n",
        "out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(\n",
        "    model_hf=model_hf, model_pt=model_pt, hf_transform=hf_transform, pt_transform=pt_video_transform\n",
        ")\n",
        "\n",
        "# Shapes are read first so that a shape mismatch is reported before the element-wise comparison\n",
        "hf_shape, pt_shape = out_patch_features_hf.shape, out_patch_features_pt.shape\n",
        "abs_diff_sum = torch.abs(out_patch_features_pt - out_patch_features_hf).sum()\n",
        "features_close = torch.allclose(out_patch_features_pt, out_patch_features_hf, atol=1e-3, rtol=1e-3)\n",
        "\n",
        "print(\n",
        "    f\"\"\"\n",
        "    Inference results on video:\n",
        "    HuggingFace output shape: {hf_shape}\n",
        "    PyTorch output shape:     {pt_shape}\n",
        "    Absolute difference sum:  {abs_diff_sum:.6f}\n",
        "    Close: {features_close}\n",
        "    \"\"\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Great! Now we know that the features from both models are equivalent. Now let's run a pretrained attentive probe classifier on top of the extracted features, to predict an action class for the video. Let's use the Something-Something V2 probe. Note that the repository also includes attentive probe weights for other evaluations such as EPIC-KITCHENS-100 and Diving48.\n",
        "\n",
        "To download the attentive probe weights, use wget and specify your preferred target path. E.g. `wget https://dl.fbaipublicfiles.com/vjepa2/evals/ssv2-vitg-384-64x2x3.pt -P YOUR_DIR`\n",
        "\n",
        "Then update `classifier_model_path` with `YOUR_DIR/ssv2-vitg-384-64x2x3.pt`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Path to the downloaded attentive probe weights\n",
        "classifier_model_path = \"YOUR_ATTENTIVE_PROBE_PATH\"\n",
        "\n",
        "# Build the attentive probe on GPU in eval mode, then load its pretrained weights\n",
        "classifier = AttentiveClassifier(\n",
        "    embed_dim=model_pt.embed_dim,\n",
        "    num_heads=16,\n",
        "    depth=4,\n",
        "    num_classes=174,\n",
        ")\n",
        "classifier = classifier.cuda().eval()\n",
        "load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)\n",
        "\n",
        "# Print the top predictions for the PyTorch encoder features\n",
        "get_vjepa_video_classification_results(classifier, out_patch_features_pt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The video features a man putting a bowling ball into a tube, so the predicted action of \"Putting [something] into [something]\" makes sense!\n",
        "\n",
        "This concludes the tutorial. Please see the README and paper for full details on the capabilities of V-JEPA 2 :)"
      ]
    }
  ],
  "metadata": {
    "fileHeader": "",
    "fileUid": "f0b70ba6-1c84-47e1-81bd-b7642f9acf50",
    "isAdHoc": false,
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
